(in-package #:coalton-native-tests)

(coalton-toplevel

  (define-type-alias Array array:LispArray)

  ;; Naive O(n^2) discrete Fourier transforms used as reference results for
  ;; the FFT routines under test.  The /c64 variants operate on (Complex F32),
  ;; the /c128 variants on (Complex F64).
  ;;
  ;; In every variant the twiddle exponent i*j is reduced modulo n before it is
  ;; turned into an angle.  cis(2*pi*k/n) is periodic in k with period n, so
  ;; the result is mathematically unchanged.  The angle, however, stays in
  ;; [0, 2*pi) instead of growing to roughly 2*pi*n; large angles lose
  ;; precision when converted to floating point, most visibly in F32.

  (declare dft/c64 (Array (Complex F32) -> Array (Complex F32)))
  (define (dft/c64 v)
    ;; Forward DFT: res[i] = sum_j v[j] * cis(-2*pi*((i*j) mod n)/n).
    (let ((n (array:length v))
          (res (array:make n 0)))
      (experimental:dotimes (i n)
        (let ((x (experimental:sumtimes (j n)
                   (* (math:cis (* (* -2 math:pi)
                                   (math:general/ (math:toInteger (math:mod (* i j) n))
                                                  (math:toInteger n))))
                      (array:aref v j)))))
          (array:set! res i x)))
      res))

  (declare idft/c64 (Array (Complex F32) -> Array (Complex F32)))
  (define (idft/c64 v)
    ;; Inverse DFT: res[i] = (1/n) * sum_j v[j] * cis(2*pi*((i*j) mod n)/n).
    (let ((n (array:length v))
          (res (array:make n 0)))
      (experimental:dotimes (i n)
        (let ((x (/ (experimental:sumtimes (j n)
                      (* (math:cis (* (* 2 math:pi)
                                      (math:general/ (math:toInteger (math:mod (* i j) n))
                                                     (math:toInteger n))))
                         (array:aref v j)))
                    (math:integral->num n))))
          (array:set! res i x)))
      res))

  (declare dft/c128 (Array (Complex F64) -> Array (Complex F64)))
  (define (dft/c128 v)
    ;; Forward DFT: res[i] = sum_j v[j] * cis(-2*pi*((i*j) mod n)/n).
    (let ((n (array:length v))
          (res (array:make n 0)))
      (experimental:dotimes (i n)
        (let ((x (experimental:sumtimes (j n)
                   (* (math:cis (* (* -2 math:pi)
                                   (math:general/ (math:toInteger (math:mod (* i j) n))
                                                  (math:toInteger n))))
                      (array:aref v j)))))
          (array:set! res i x)))
      res))

  (declare idft/c128 (Array (Complex F64) -> Array (Complex F64)))
  (define (idft/c128 v)
    ;; Inverse DFT: res[i] = (1/n) * sum_j v[j] * cis(2*pi*((i*j) mod n)/n).
    (let ((n (array:length v))
          (res (array:make n 0)))
      (experimental:dotimes (i n)
        (let ((x (/ (experimental:sumtimes (j n)
                      (* (math:cis (* (* 2 math:pi)
                                      (math:general/ (math:toInteger (math:mod (* i j) n))
                                                     (math:toInteger n))))
                         (array:aref v j)))
                    (math:integral->num n))))
          (array:set! res i x)))
      res))

  ;; Euclidean norm of a complex array: the square root of the sum of
  ;; math:square-magnitude over all elements.
  (define (complex-array-magnitude v)
    (sqrt (experimental:sumtimes (i (array:length v))
            (let ((x (array:aref v i)))
              (math:square-magnitude x)))))

  ;; Percent difference of X and Y relative to their mean:
  ;; 200 * |x - y| / (x + y).  When x + y is zero (e.g. both magnitudes are
  ;; zero) the result is 0 rather than the NaN a 0/0 division would produce.
  (define (%diff x y)
    (let ((total (+ x y)))
      (if (== total 0)
          0
          (* 200 (/ (abs (- x y)) total)))))

  ;; Returns X unchanged, or 1 when X is zero, so it can be used as a divisor
  ;; without dividing by zero.
  (define (%nonzero-or-one x)
    (if (== x 0) 1 x))

  ;; Asserts that complex arrays V1 and V2 are approximately equal:
  ;;   - their lengths match;
  ;;   - their norms differ by less than 0.1 percent-difference units;
  ;;   - after dividing each array by its own norm, corresponding elements
  ;;     differ by less than 1/1000 in magnitude.
  ;; Only the common prefix of the two arrays is compared element-wise, so a
  ;; length mismatch is reported by the first check instead of causing an
  ;; out-of-bounds read.  An all-zero array is divided by 1 instead of by
  ;; its zero norm.
  (define (is-complex-array~= v1 v2)
    (let ((n1 (array:length v1))
          (n2 (array:length v2)))
      (is (== n1 n2) "Unequal lengths.")
      (let ((mag1 (complex-array-magnitude v1))
            (mag2 (complex-array-magnitude v2)))
        (is (< (%diff mag1 mag2) (/ 1 10)) "Dissimilar magnitudes.")
        (let ((scale1 (complex (%nonzero-or-one mag1) 0))
              (scale2 (complex (%nonzero-or-one mag2) 0)))
          (experimental:dotimes (i (if (< n1 n2) n1 n2))
            (let ((x1 (/ (array:aref v1 i) scale1))
                  (x2 (/ (array:aref v2 i) scale2)))
              (is (< (math:magnitude (- x1 x2)) (/ 1 1000))
                  "Dissimilar elements."))))))))

  ;; Builds an array of N samples of a complex exponential,
  ;; v[i] = cis(2*pi*f*i), i.e. F cycles per sample.
  (define (make-sample-array n f)
    (let ((v (array:make n 0)))
      (experimental:dotimes (i n)
        (let ((x (math:cis (* f (* (* 2 math:pi) (math:integral->num i))))))
          (array:set! v i x)))
      v)))

(define-test fft-test ()
  ;; Single precision: the out-of-place FFT must agree with the naive DFT.
  (let ((signal (make-sample-array 32 0.25f0)))
    (let ((reference (dft/c64 signal))
          (transformed (the (Array (Complex F32)) (fft:fft signal))))
      (is-complex-array~= transformed reference)))

  ;; Double precision: the same comparison on (Complex F64) samples.
  (let ((signal (make-sample-array 32 0.25d0)))
    (let ((reference (dft/c128 signal))
          (transformed (the (Array (Complex F64)) (fft:fft signal))))
      (is-complex-array~= transformed reference))))

(define-test ifft-test ()
  ;; Single precision: the out-of-place inverse FFT must agree with the
  ;; naive inverse DFT.
  (let ((signal (make-sample-array 32 0.25f0)))
    (let ((reference (idft/c64 signal))
          (transformed (the (Array (Complex F32)) (fft:ifft signal))))
      (is-complex-array~= transformed reference)))

  ;; Double precision: the same comparison on (Complex F64) samples.
  (let ((signal (make-sample-array 32 0.25d0)))
    (let ((reference (idft/c128 signal))
          (transformed (the (Array (Complex F64)) (fft:ifft signal))))
      (is-complex-array~= transformed reference))))

(define-test fft!-test ()
  ;; Single precision: fft! is applied to a copy of the samples, so the
  ;; original signal stays available for the reference DFT.
  (let ((signal (make-sample-array 32 0.25f0)))
    (let ((reference (dft/c64 signal)))
      (is-complex-array~= (fft:fft! (array:copy signal)) reference)))

  ;; Double precision: the same check on (Complex F64) samples.
  (let ((signal (make-sample-array 32 0.25d0)))
    (let ((reference (dft/c128 signal)))
      (is-complex-array~= (fft:fft! (array:copy signal)) reference))))

(define-test ifft!-test ()
  ;; Single precision: ifft! is applied to a copy of the samples, so the
  ;; original signal stays available for the reference inverse DFT.
  (let ((signal (make-sample-array 32 0.25f0)))
    (let ((reference (idft/c64 signal)))
      (is-complex-array~= (fft:ifft! (array:copy signal)) reference)))

  ;; Double precision: the same check on (Complex F64) samples.
  (let ((signal (make-sample-array 32 0.25d0)))
    (let ((reference (idft/c128 signal)))
      (is-complex-array~= (fft:ifft! (array:copy signal)) reference))))

(define-test fft-into!-test ()
  ;; Single precision: fft-into! receives a fresh uninitialized destination
  ;; of the same length as the input; the array it returns must match the
  ;; naive DFT.
  (let ((signal (make-sample-array 32 0.25f0)))
    (let ((reference (dft/c64 signal))
          (target (the (Array (Complex F32))
                       (array:make-uninitialized (array:length signal)))))
      (is-complex-array~= (fft:fft-into! target signal) reference)))

  ;; Double precision: the same check on (Complex F64) samples.
  (let ((signal (make-sample-array 32 0.25d0)))
    (let ((reference (dft/c128 signal))
          (target (the (Array (Complex F64))
                       (array:make-uninitialized (array:length signal)))))
      (is-complex-array~= (fft:fft-into! target signal) reference))))

(define-test ifft-into!-test ()
  ;; Single precision: ifft-into! receives a fresh uninitialized destination
  ;; of the same length as the input; the array it returns must match the
  ;; naive inverse DFT.
  (let ((signal (make-sample-array 32 0.25f0)))
    (let ((reference (idft/c64 signal))
          (target (the (Array (Complex F32))
                       (array:make-uninitialized (array:length signal)))))
      (is-complex-array~= (fft:ifft-into! target signal) reference)))

  ;; Double precision: the same check on (Complex F64) samples.
  (let ((signal (make-sample-array 32 0.25d0)))
    (let ((reference (idft/c128 signal))
          (target (the (Array (Complex F64))
                       (array:make-uninitialized (array:length signal)))))
      (is-complex-array~= (fft:ifft-into! target signal) reference))))
